"""For basic init and call"""
import os

from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.llms import QianfanLLMEndpoint
from langchain.sql_database import SQLDatabase

os.environ["QIANFAN_AK"] = "ZGgmCMM95MfabqYLG2swwRWM"
os.environ["QIANFAN_SK"] = "yDAIOpvEu7vm2phjYDT5sdIg5h4LBb4v"

llm = QianfanLLMEndpoint(streaming=True)
res = llm("你好，千帆")
print(res)

# 数据库
db_user = "root"
db_password = "root123"
db_host = "localhost:3306"
db_name = "group_insurance"
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

from langchain_experimental.sql import SQLDatabaseChain


# SQLDatabaseChain 是一个简单的链，允许对数据库执行 SQL 查询。它需要一个 SQLDatabase 对象并顺序调用 sql_query 和 sql_print_result 等工具来运行和打印查询。

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("描述与用户或者客户相关的表及其关系")
# db_chain.run("列出订单数量最多的前3个客户名称")
